package database

import dsl.Query
import dsl.TableSchema
import query.select.NativeSelect
import query.select.Select
import query.select.ValuesSelect
import query.select.WithSelect
import util.getOutPutVisitor
import visitor.checkOLAP
import visitor.getExpr
import visitor.getQueryExpr
import java.sql.Connection
import javax.sql.DataSource

/**
 * 基础查询类
 * @property db DB 数据库类型
 * @property isTransaction Boolean 值为false，代表不是事务操作
 * @property dataSource DataSource 数据库连接池
 */
class DBConnection(source: DataSource, override val db: DB) : DataBaseImpl() {
    override val isTransaction: Boolean = false

    private var dataSource: DataSource = source

    /**
     * 创建一个事务
     * @param isolation Int? java.sql.Connection中定义的隔离级别（可以省略）
     * @param query [@kotlin.ExtensionFunctionType] Function1<DBTransaction, Unit> 事务操作lambda，可以传入若干个数据库操作，如果内部出错会回滚事务
     */
    inline fun transaction(isolation: Int? = null, query: DBTransaction.() -> Unit) {
        checkOLAP(this.db)

        val conn = getConnection()
        conn.autoCommit = false
        isolation?.let { conn.transactionIsolation = it }

        try {
            query(DBTransaction(this.db, conn))
            conn.commit()
        } catch (e: Exception) {
            e.printStackTrace()
            conn.rollback()
        } finally {
            conn.autoCommit = true
            conn.close()
        }
    }

    /**
     * 从连接池获取数据库连接
     * @return Connection 数据库连接
     */
    override fun getConnection(): Connection {
        return this.dataSource.connection
    }

    /**
     * 创建一个select查询
     * 例如：db.select("c1", "c2")
     * @param columns Array<out String> 字段名列表
     * @return Select 查询dsl
     */
    override fun select(vararg columns: String): Select {
        val select = Select(db, getConnection(), isTransaction, this)
        select.select(*columns)
        return select
    }

    /**
     * 创建一个select查询
     * 例如：db select count()
     * @param query Query 查询表达式
     * @return Select 查询dsl
     */
    override infix fun select(query: Query): Select {
        val select = Select(db, getConnection(), isTransaction, this)
        select.invoke(query)
        return select
    }

    /**
     * 创建一个select查询
     * 例如：db select listOf(count(), sum("c1"))
     * @param query List<Query> 查询表达式列表
     * @return Select 查询dsl
     */
    override infix fun select(query: List<Query>): Select {
        val select = Select(db, getConnection(), isTransaction, this)
        select.invoke(query)
        return select
    }

    /**
     * 创建一个select查询
     * 例如：db.select(count(), sum("c1"))
     * @param query Array<out Query> 查询表达式列表
     * @return Select 查询dsl
     */
    override fun select(vararg query: Query): Select {
        val select = Select(db, getConnection(), isTransaction, this)
        select.select(*query)
        return select
    }

    /**
     * 创建一个select查询
     * 例如：db.select()
     * @return Select 查询dsl
     */
    override fun select(): Select {
        return Select(db, getConnection(), isTransaction, this)
    }

    /**
     * 创建一个select查询
     * @param table String 表名
     * @return Select
     */
    override fun from(table: String): Select {
        return Select(db, getConnection(), isTransaction, this).from(table)
    }

    /**
     * 创建一个select查询
     * @param table TableSchema 实体类伴生对象名
     * @return Select
     */
    override fun from(table: TableSchema): Select {
        return Select(db, getConnection(), isTransaction, this).from(table)
    }

    /**
     * 创建一个原生sql查询
     * 例如：db.nativeSelect("select * from t1 where c1 = ?", 1)
     * @param sql String 查询sql语句
     * @param arg Array<out Any> 查询参数列表（可省略），查询语句中的?会被arg中的参数依次替换，合法的类型有Number、String、Date、List、Boolean以及null和Query表达式类型
     * @return NativeSelect 原生sql查询
     */
    override fun nativeSelect(sql: String, vararg arg: Any): NativeSelect {
        val argList = arg.map {
            if (it is Query) {
                val visitor = getOutPutVisitor(db)
                visitor.visitSqlExpr(getQueryExpr(it, db).expr)
                visitor.sql()
            } else {
                getExpr(it).toString()
            }
        }
        var nativeSql = sql
        if (sql.contains("?")) {
            argList.forEach {
                nativeSql = nativeSql.replaceFirst("?", it)
            }
        }
        return NativeSelect(db, nativeSql, getConnection(), isTransaction, this)
    }

    /**
     * 创建一个with查询
     * @return WithSelect with查询dsl
     */
    override fun with(): WithSelect {
        return WithSelect(db, getConnection(), isTransaction, this)
    }

    /**
     * 创建一个values查询
     * @param value Array<out List<Any>> value列表
     * @return ValuesSelect values查询dsl
     */
    override fun values(vararg value: List<Any>): ValuesSelect {
        val values = ValuesSelect(db, getConnection(), isTransaction, this)
        value.forEach {
            values.addRow(it)
        }
        return values
    }
}